# coding:utf-8

from src.config.oracle_config import OracleConfig
from src.config.stock_trend_prediction_config import StockTrendPredictionConfig
from src.manager.oracle_manager import OracleManager
from src.manager.log_manager import LogManager
from sklearn import datasets
from sklearn.preprocessing import OneHotEncoder
from sklearn.cluster import KMeans
from numpy import unique
from numpy import where
from matplotlib import pyplot
import pandas as pd
import numpy as np

Logger = LogManager.get_logger(__name__)


class StockTrendPredictionHandler:
    """
    Stock trend prediction handler.

    Reads the ``close_price_ma_order`` values of every stock in
    ``mdl_stock_analysis`` within the configured prediction window, turns them
    into one feature row per stock, and groups the stocks with K-Means.
    """

    def __init__(self) -> None:
        super().__init__()

        # Open one Oracle connection; every query in this handler reuses this cursor.
        self.oracle_manager = OracleManager(OracleConfig.Username, OracleConfig.Password, OracleConfig.Url)
        self.oracle_manager.connect()
        self.cursor = self.oracle_manager.get_cursor()

    def rotate_row_by_close_price_ma_order(self):
        """
        Pivot rows into columns: build one ``close_price_ma_order`` row per stock.

        For each distinct stock code present in ``mdl_stock_analysis`` between
        ``StockTrendPredictionConfig.Prediction_Begin_Date`` and
        ``StockTrendPredictionConfig.Prediction_End_Date``, fetch at most
        ``StockTrendPredictionConfig.Date_Number`` of its rows, newest date first.

        :return: ``(data_list, code_list, name_list)`` — parallel lists where
            ``data_list[i]`` holds the ``close_price_ma_order`` values (newest
            first) of stock ``code_list[i]`` whose name is ``name_list[i]``.
            Stocks whose detail query returns no rows (e.g. no matching
            ``stock_info`` record) are skipped.
        """
        begin_date = StockTrendPredictionConfig.Prediction_Begin_Date
        end_date = StockTrendPredictionConfig.Prediction_End_Date

        # Distinct stock codes that have data inside the prediction window.
        self.cursor.execute(
            "select distinct t.code_ from mdl_stock_analysis t "
            "where t.date_ between to_date('%s','yyyy-mm-dd') and to_date('%s','yyyy-mm-dd')" % (
                begin_date, end_date))
        distinct_code_list = self.cursor.fetchall()

        data_list = list()
        code_list = list()
        name_list = list()
        for code in distinct_code_list:
            # The inner query sorts by date desc; the outer ROWNUM filter then
            # keeps only the newest Date_Number rows.
            self.cursor.execute("select * from ( "
                                "select t.id_, t.date_, t.code_, t.close_price_ma_order, si_.name_ "
                                "from mdl_stock_analysis t "
                                "join stock_info si_ on si_.code_=t.code_ "
                                "where t.code_='%s' "
                                "and t.date_ between to_date('%s','yyyy-mm-dd') and to_date('%s','yyyy-mm-dd') "
                                "order by t.date_ desc) where rownum<=%s" % (
                                    code[0], begin_date, end_date,
                                    StockTrendPredictionConfig.Date_Number))
            rows = self.cursor.fetchall()
            if not rows:
                # The inner join with stock_info (or a non-positive Date_Number)
                # can yield no rows; rows[0] below would raise IndexError.
                Logger.info("no close_price_ma_order rows for code " + str(code[0]) + ", skipped")
                continue
            # Column 3 is close_price_ma_order, 2 is code_, 4 is stock_info.name_.
            data_list.append([row[3] for row in rows])
            code_list.append(rows[0][2])
            name_list.append(rows[0][4])
        return data_list, code_list, name_list

    def predict_stock_trend(self, n_clusters=3, random_state=None):
        """
        Predict stock trends by clustering the stocks with K-Means.

        Each stock's ``close_price_ma_order`` row is one-hot encoded column by
        column, and the encoded rows are clustered. The cluster label of every
        stock is logged together with its code and name.

        :param n_clusters: number of K-Means clusters (default 3, the value
            previously hard-coded). sklearn's ``KMeans`` raises ``ValueError``
            if there are fewer stocks than clusters.
        :param random_state: forwarded to ``KMeans``; ``None`` (default) keeps
            sklearn's default, non-reproducible seeding.
        :return: list of cluster labels aligned with the stock order produced
            by ``rotate_row_by_close_price_ma_order``; empty list if no stock
            data was found.
        """
        data_list, code_list, name_list = self.rotate_row_by_close_price_ma_order()
        if not data_list:
            # An empty frame cannot be encoded or clustered.
            Logger.info("no stock data in the prediction window, nothing to predict")
            return []

        # Treat every close_price_ma_order position as a categorical feature.
        X = pd.DataFrame(data_list)
        transform_data = OneHotEncoder().fit_transform(X)

        # fit_predict is documented as equivalent to fit(X).predict(X).
        clf = KMeans(n_clusters=n_clusters, random_state=random_state)
        prediction_list = clf.fit_predict(transform_data).tolist()

        for prediction, code, name in zip(prediction_list, code_list, name_list):
            Logger.info(
                "预测值：" + str(prediction) + "，股票代码：" + str(code) + "，股票名称：" + str(name))
        return prediction_list
